import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from IPy import IP
from global_var import *
from sklearn import metrics
from time import *
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OrdinalEncoder
from sklearn.utils import resample
from sklearn.preprocessing import MinMaxScaler


def load_data(dataset, subset, mode='train', **kwargs):
    """Load a dataset by key, split it, and return the requested partition.

    The full data is split 80/20 into train/test with ``random_state=SEED``.
    In ``'train'`` mode the 80% part is further split 75/25 into train/eval
    (same seed) and the training part keeps only rows whose label is 0.

    Args:
        dataset: Key selecting one of the ``load_*`` loaders in this module.
        subset: Subset / file name forwarded to the loader.
        mode: ``'train'`` or ``'test'``.
        **kwargs: Forwarded to ``load_cicids_improved`` / ``load_cse_improved``
            (e.g. ``feat_size``). ``random_select=k`` additionally draws ``k``
            training rows uniformly *with replacement* (``np.random.randint``).

    Returns:
        ``(X_train, X_eval, y_train, y_eval)`` in train mode, with labels cast
        to int; ``(X_test, y_test)`` in test mode.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
        SystemExit: if ``dataset`` is unknown (after printing a message).
    """
    # Fail fast instead of loading everything and then silently returning None.
    if mode not in ('train', 'test'):
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

    if dataset == 'cicids':
        X, y = load_cicids(subset)
    elif dataset in ('cicids_custom', 'cicids_custom1'):
        # NOTE(review): 'cicids_custom1' dispatches to load_cicids_custom, not
        # load_cicids_custom1 -- confirm this is intended.
        X, y = load_cicids_custom(subset)
    elif dataset == 'toniot_custom':
        X, y = load_toniot_custom(subset)
    elif dataset == 'cicids_improved':
        X, y = load_cicids_improved(subset, **kwargs)
    elif dataset == 'cse_improved':
        X, y = load_cse_improved(subset, **kwargs)
    elif dataset == 'KDD':
        X, y = load_KDD(subset)
    elif dataset == 'cidds':
        X, y = load_cidds(subset)
    elif dataset == 'CICDDOS':
        X, y = load_CICDDOS(subset)
    elif dataset == 'Web':
        X, y = load_web(subset)
    elif dataset == 'CICIDS_imporved':
        # Misspelled legacy key kept for existing callers; note that unlike
        # 'cicids_improved' it does not forward kwargs.
        X, y = load_cicids_improved(subset)
    else:
        print('no such dataset')
        # Explicit SystemExit rather than the site-provided ``exit()`` builtin
        # (absent under ``python -S``); non-zero status signals the error.
        raise SystemExit(1)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=SEED)
    if mode == 'test':
        return X_test, y_test

    X_train, X_eval, y_train, y_eval = train_test_split(X_train, y_train, test_size=0.25, random_state=SEED)

    # Train only on label-0 rows (presumably the benign/normal class).
    keep = y_train == 0
    X_train, y_train = X_train[keep], y_train[keep]
    if 'random_select' in kwargs:
        idx_rand = np.random.randint(0, X_train.shape[0], kwargs['random_select'])
        X_train, y_train = X_train[idx_rand], y_train[idx_rand]

    return X_train, X_eval, y_train.astype(int), y_eval.astype(int)




def load_web(subset, n_samples=10000):
    """Load a random row sample of the Web dataset.

    Reads ``../dataset/Web/{subset}.csv`` (malformed lines skipped), drops rows
    containing any NaN, keeps nine feature columns and draws
    ``min(n_samples, rows)`` rows without replacement.

    Args:
        subset: CSV file stem.
        n_samples: Maximum number of rows to sample (default 10000).

    Returns:
        (X, y): feature array and ``1 - label`` (label values swapped,
        presumably a 0/1 column).
    """
    df = pd.read_csv(os.path.join('../dataset', 'Web', f'{subset}.csv'), encoding='utf-8', on_bad_lines='skip',
                     nrows=None)
    df.dropna(how='any', inplace=True)

    feature_cols = ['url_cluster', 'url_len', 'geo_loc', 'tld', 'who_is', 'https',
                    'content_cluster', 'js_len', 'js_obf_len']
    X = df[feature_cols]
    y = df['label']

    # Cap the sample size: np.random.choice(..., replace=False) raises when
    # asked for more rows than exist.
    size = min(n_samples, len(X))
    indices = np.random.choice(len(X), size=size, replace=False)
    X = np.array(X)[indices]
    y = np.array(y)[indices]
    y = 1 - y
    return X, y









def load_CICDDOS(subset, n_samples=20000):
    """Load a random row sample of a CICDDOS-2019 CSV.

    The file ``../dataset/CICDDOS-2019/{subset}.csv`` is read without a header.
    Rows containing any NaN are dropped, then ``min(n_samples, rows)`` rows are
    drawn without replacement.

    Args:
        subset: CSV file stem.
        n_samples: Maximum number of rows to sample (default 20000).

    Returns:
        (X, y): all columns except the last, and the last column.
    """
    data = pd.read_csv(os.path.join('../dataset', 'CICDDOS-2019', f'{subset}.csv'), header=None)
    data.dropna(how='any', inplace=True)

    # Cap the sample size: np.random.choice(..., replace=False) raises when
    # asked for more rows than exist.
    size = min(n_samples, len(data))
    indices = np.random.choice(len(data), size=size, replace=False)
    data = np.array(data)[indices]

    X = data[:, :-1]
    y = data[:, -1]
    return X, y



def load_cidds(subset, n_normal=29000, replace=True):
    """Load a CIDDS CSV and rebalance it by resampling attack and normal rows.

    Reads ``../../dataset/cidds/{subset}.csv`` (two directory levels up, unlike
    the other loaders) and drops rows with any NaN. It then draws as many rows
    as there are ``attack_type == 1`` rows from those rows, plus ``n_normal``
    rows from the ``attack_type == 0`` rows, concatenates the two draws and
    shuffles them.

    Args:
        subset: CSV file stem.
        n_normal: Number of normal rows to draw (default 29000).
        replace: Sample with replacement (default, the historical behaviour).
            With replacement the attack draw is a bootstrap, so duplicate rows
            are expected and may land on both sides of a later train/test
            split. Pass ``replace=False`` to use each row at most once; then
            ``n_normal`` must not exceed the number of normal rows.

    Returns:
        (X, y): all columns but the last, and the last column.
        NOTE(review): assumes ``attack_type`` is the last column -- confirm,
        otherwise the label also ends up inside X.
    """
    data = pd.read_csv(os.path.join('../../dataset', 'cidds', f'{subset}.csv'), encoding='utf-8',
                       on_bad_lines='skip',
                       nrows=None)
    data.dropna(how='any', inplace=True)

    attacks = data[data['attack_type'] == 1]
    normals = data[data['attack_type'] == 0]

    # Draw order (attacks first, then normals) is kept so the global NumPy RNG
    # stream is consumed exactly as before for the default arguments.
    df_sample_y1 = attacks.sample(n=len(attacks), replace=replace)
    df_sample_y0 = normals.sample(n=n_normal, replace=replace)
    df_sample = pd.concat([df_sample_y1, df_sample_y0])

    # Shuffle so the attack and normal rows are interleaved.
    data = df_sample.sample(frac=1).reset_index(drop=True)
    X = data.iloc[:, :-1].values
    y = data.iloc[:, -1].values
    return X, y








def load_KDD(subset, all_features=True, n_samples=10000):
    """Load a random sample of a KDD data file as (features, binary labels).

    Reads ``../dataset/KDD/{subset}.data.txt`` and draws
    ``min(n_samples, rows)`` rows without replacement. The last column is
    mapped to 0 for ``'normal.'`` and 1 for anything else. Every other column
    is ordinal-encoded, and the whole array is then standardized column-wise.

    Args:
        subset: File stem of the KDD data file.
        all_features: If true, return the first 41 columns; otherwise return
            only columns 3, 4, 5, 6, 8, 10, 13, 23, 24 and 37.
        n_samples: Maximum number of rows to sample (default 10000).

    Returns:
        (data_feature, data_label): float feature matrix and an int array of
        0/1 labels.
    """
    # NOTE(review): read_csv infers a header here, so the file's first line is
    # consumed as column names -- confirm the file really has a header row.
    fr = pd.read_csv(os.path.join('../dataset', 'KDD', f'{subset}.data.txt'), encoding='utf-8',
                     on_bad_lines='skip', nrows=None)

    # Cap the sample size: np.random.choice(..., replace=False) raises when
    # asked for more rows than exist.
    size = min(n_samples, len(fr))
    indices = np.random.choice(len(fr), size=size, replace=False)
    data = np.array(fr)[indices]

    # Binary label from the raw (unscaled) last column. It must not be
    # re-derived from the standardized column: truncating a z-score with
    # int() does not give back 0/1.
    data_label = np.where(data[:, -1] == 'normal.', 0, 1)
    data[:, -1] = data_label
    data[:, 0:-1] = OrdinalEncoder().fit_transform(data[:, 0:-1])  # categorical encoding of the features
    # Column-wise standardization; scaling is per column, so including the
    # label column does not change the feature values.
    data = StandardScaler().fit_transform(data)

    if all_features:
        data_feature = np.array(data[:, 0:41], dtype=float)
    else:
        selected = [3, 4, 5, 6, 8, 10, 13, 23, 24, 37]
        data_feature = np.array(data[:, selected], dtype=float)

    return data_feature, data_label






# CICIDS

def encode_label_cicids(col: pd.Series):
    """Build label <-> integer-id mappings for CICIDS label names.

    ``'BENIGN'`` always maps to 0. The remaining distinct labels get ids
    1..k in sorted order. The original code iterated a ``set``, whose order
    for strings changes between processes (hash randomization); sorting makes
    the mapping reproducible. ``'BENIGN'`` does not have to be present in
    ``col``.

    Args:
        col: Series of string labels.

    Returns:
        (a2l, l2a): dict from label name to id, and its inverse.
    """
    attacks = sorted(set(col) - {'BENIGN'})
    a2l = {'BENIGN': 0}
    a2l.update({att: i for i, att in enumerate(attacks, start=1)})
    l2a = {i: att for att, i in a2l.items()}
    return a2l, l2a


def load_cicids(subset):
    """Load one CICIDS-2017 day file as (features, integer labels).

    Reads ``CICIDS_DICT[subset] + '.pcap_ISCX.csv'`` from ``CICIDS_DIR`` and
    keeps the IP, feature and label columns. Rows with any missing value are
    dropped. Benign rows are kept only when their destination IP is in
    ``CICIDS_SERVER_IPS``; every non-benign row is kept. Labels are converted
    to ids with ``encode_label_cicids``.
    """
    csv_path = os.path.join(CICIDS_DIR, CICIDS_DICT[subset] + '.pcap_ISCX.csv')
    wanted_cols = CICIDS_IP_COLS + CICIDS_FEAT_COLS + [CICIDS_LABEL_COL]
    frame = pd.read_csv(csv_path)[wanted_cols]
    frame.dropna(how='any', inplace=True)

    to_server = frame[' Destination IP'].isin(CICIDS_SERVER_IPS)
    is_attack = frame[CICIDS_LABEL_COL] != 'BENIGN'
    frame = frame[to_server | is_attack]

    features = frame[CICIDS_FEAT_COLS].to_numpy()
    label_to_id, _ = encode_label_cicids(frame[CICIDS_LABEL_COL])
    labels = frame[CICIDS_LABEL_COL].apply(lambda name: label_to_id[name]).to_numpy()
    return features, labels


def load_cicids_custom(subset):
    """Load a preprocessed CICIDS-2017 CSV from ``../dataset/CICIDS-2017``.

    Every row with a non-zero label is kept. Among label-0 rows, only those
    whose ``'dest-ip'`` is in ``CICIDS_SERVER_IPS`` are kept.

    Returns:
        (X, y): ``CUSTOM_FEAT_COLS`` values and ``CUSTOM_LABEL_COL`` values.
    """
    csv_path = os.path.join('../dataset', 'CICIDS-2017', f'{subset}.csv')
    frame = pd.read_csv(csv_path)

    keep = frame['dest-ip'].isin(CICIDS_SERVER_IPS) | (frame[CUSTOM_LABEL_COL] != 0)
    kept = frame[keep]

    return kept[CUSTOM_FEAT_COLS].to_numpy(), kept[CUSTOM_LABEL_COL].to_numpy()

def load_cicids_custom1(subset):
    """Variant of the CICIDS-2017 custom loader that selects ``CUSTOM_FEAT_COL``.

    Keeps rows whose ``'dest-ip'`` is in ``CICIDS_SERVER_IPS`` or whose label
    is non-zero.

    NOTE(review): the sibling loaders load_cicids_custom and load_toniot_custom
    use ``CUSTOM_FEAT_COLS`` (plural). Confirm that ``CUSTOM_FEAT_COL`` exists
    in global_var; otherwise this raises NameError.
    """
    csv_path = os.path.join('../dataset', 'CICIDS-2017', f'{subset}.csv')
    frame = pd.read_csv(csv_path)

    wanted_rows = frame['dest-ip'].isin(CICIDS_SERVER_IPS) | (frame[CUSTOM_LABEL_COL] != 0)
    kept = frame[wanted_rows]

    features = kept[CUSTOM_FEAT_COL].to_numpy()
    labels = kept[CUSTOM_LABEL_COL].to_numpy()
    return features, labels


    
def load_Kyoto2016(subset, n_samples=5000):
    """Load a random sample of a Kyoto2016 CSV as (features, binary labels).

    Reads ``../dataset/Kyoto2016/{subset}.csv`` and drops rows with any NaN.
    It keeps only rows whose column ``'18'`` equals 1 or -1 and draws
    ``min(n_samples, rows)`` of them without replacement. Shapes are printed
    for debugging.

    Args:
        subset: CSV file stem.
        n_samples: Maximum number of rows to sample (default 5000).

    Returns:
        (X, y): the first 14 columns, and column ``'18'`` remapped so that
        1 -> 0 and -1 -> 1 (dtype of the column preserved).
    """
    df = pd.read_csv(os.path.join('../dataset', 'Kyoto2016', f'{subset}.csv'))
    df.dropna(how='any', inplace=True)
    print(df.shape)

    labelled = df[df['18'].isin([1, -1])]
    # Cap the sample size: np.random.choice(..., replace=False) raises when
    # asked for more rows than exist.
    size = min(n_samples, len(labelled))
    indices = np.random.choice(len(labelled), size=size, replace=False)
    sampled = labelled.iloc[indices]

    X = sampled.iloc[:, :14].values
    raw = sampled['18'].to_numpy()
    # 1 -> 0, -1 -> 1 (only these two values remain after the filter above).
    y = np.where(raw == -1, 1, 0).astype(raw.dtype)

    print(X.shape, y.shape)
    return X, y
# CICIDS-improved
def load_cicids_improved(subset, **kwargs):
    """Load an improved-CICIDS CSV as (features, binary labels).

    Reads ``{CICIDS_2_DIR}/{subset.lower()}.csv`` and keeps only rows whose
    ``CICIDS_2_ATTEMPT_COL`` equals -1. Benign rows are kept only when the
    second IP column (presumably the destination -- confirm) is in
    ``CICIDS_2_SERVER_IPS``; all non-benign rows are kept.

    Keyword Args:
        feat_size: If given, use the first ``feat_size`` entries of
            ``CICIDS_2_FEAT_ALL_COLS`` as features; otherwise use
            ``CICIDS_2_FEAT_COLS``.

    Returns:
        (X, y): feature array and int labels (1 for non-``'BENIGN'``).

    Raises:
        KeyError: if a required column is missing from the CSV.
    """
    subset = str(subset).lower()
    df = pd.read_csv(os.path.join(CICIDS_2_DIR, subset + '.csv'))

    # Explicit choice instead of a bare ``except:``, which also hid missing
    # columns, bad feat_size values and even KeyboardInterrupt.
    if 'feat_size' in kwargs:
        columns_to_extract = CICIDS_2_FEAT_ALL_COLS[:kwargs['feat_size']]
    else:
        columns_to_extract = CICIDS_2_FEAT_COLS
    df = df[CICIDS_2_IP_COLS + columns_to_extract + [CICIDS_2_LABEL_COL, CICIDS_2_ATTEMPT_COL]]

    # Filter on the attempted flag.
    df = df[df[CICIDS_2_ATTEMPT_COL] == -1]

    # Benign traffic only towards the listed servers; every attack row is kept.
    cond = df[CICIDS_2_IP_COLS[1]].isin(CICIDS_2_SERVER_IPS)
    df = df[cond | (df[CICIDS_2_LABEL_COL] != 'BENIGN')]

    X = df[columns_to_extract].to_numpy()
    y = (df[CICIDS_2_LABEL_COL] != 'BENIGN').astype('int').to_numpy()
    return X, y

# CSE-CICIDS-2018-improved
def load_cse_improved(subset, **kwargs):
    """Load an improved CSE-CIC-IDS-2018 CSV as (features, binary labels).

    Reads ``{CSE_DIR}/{subset.lower()}.csv`` and keeps only rows whose
    ``CICIDS_2_ATTEMPT_COL`` equals -1. Benign rows are kept only when the
    second IP column (presumably the destination -- confirm) is in
    ``CSE_SERVER_IPS``; all non-benign rows are kept.

    Keyword Args:
        feat_size: If given, use the first ``feat_size`` entries of
            ``CICIDS_2_FEAT_ALL_COLS`` as features; otherwise use
            ``CICIDS_2_FEAT_COLS``.

    Returns:
        (X, y): feature array and int labels (1 for non-``'BENIGN'``).

    Raises:
        KeyError: if a required column is missing from the CSV.
    """
    subset = str(subset).lower()
    df = pd.read_csv(os.path.join(CSE_DIR, subset + '.csv'))

    # Explicit choice instead of a bare ``except:``, which also hid missing
    # columns, bad feat_size values and even KeyboardInterrupt.
    if 'feat_size' in kwargs:
        columns_to_extract = CICIDS_2_FEAT_ALL_COLS[:kwargs['feat_size']]
    else:
        columns_to_extract = CICIDS_2_FEAT_COLS
    df = df[CICIDS_2_IP_COLS + columns_to_extract + [CICIDS_2_LABEL_COL, CICIDS_2_ATTEMPT_COL]]

    # Filter on the attempted flag.
    df = df[df[CICIDS_2_ATTEMPT_COL] == -1]

    # Benign traffic only towards the listed servers; every attack row is kept.
    cond = df[CICIDS_2_IP_COLS[1]].isin(CSE_SERVER_IPS)
    df = df[cond | (df[CICIDS_2_LABEL_COL] != 'BENIGN')]

    X = df[columns_to_extract].to_numpy()
    y = (df[CICIDS_2_LABEL_COL] != 'BENIGN').astype('int').to_numpy()
    return X, y



def load_toniot_custom(subset, normal_ips=('3.122.49.24',)):
    """Build a TON-IoT set from one file's attack rows plus filtered normal files.

    Attack rows (``label == 1``) come from ``../dataset/TON-IoT/{subset}.csv``.
    Every file in that directory whose name starts with ``'normal'`` is also
    read, keeping only rows whose ``src_ip`` or ``dst_ip`` is in
    ``normal_ips``.

    Args:
        subset: CSV file stem holding the attack rows.
        normal_ips: IPs used to select rows from the normal files
            (default: the previously hard-coded ``'3.122.49.24'``).

    Returns:
        (X, y): ``CUSTOM_FEAT_COLS`` values and ``CUSTOM_LABEL_COL`` values.
    """
    data_dir = os.path.join('../dataset', 'TON-IoT')
    df = pd.read_csv(os.path.join(data_dir, f'{subset}.csv'))
    df_list = [df[df['label'] == 1]]

    ips = list(normal_ips)
    # os.listdir order is arbitrary; sort it so the concatenated row order
    # (which feeds the seeded train/test split in load_data) is reproducible.
    for f in sorted(os.listdir(data_dir)):
        if f.startswith('normal'):
            df_norm = pd.read_csv(os.path.join(data_dir, f))
            cond = df_norm['src_ip'].isin(ips) | df_norm['dst_ip'].isin(ips)
            df_list.append(df_norm[cond])
    df = pd.concat(df_list)

    X = df[CUSTOM_FEAT_COLS].to_numpy()
    y = df[CUSTOM_LABEL_COL].to_numpy()
    return X, y